{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from packaging.version import parse\n",
    "\n",
    "from fastai.basics import *\n",
    "from fastai.vision.core import *\n",
    "from fastai.vision.data import *\n",
    "from fastai.vision.augment import *\n",
    "from fastai.vision import models\n",
    "\n",
    "import torchvision\n",
    "try: import timm\n",
    "except ModuleNotFoundError: pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp vision.learner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Vision learner\n",
    "\n",
    "> All the functions necessary to build `Learner` suitable for transfer learning in computer vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The most important functions of this module are `vision_learner` and `unet_learner`. They will help you define a `Learner` using a pretrained model. See the [vision tutorial](23_tutorial.vision.ipynb) for examples of use."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cut a pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _is_pool_type(l):\n",
    "    \"Search the class name of `l` for a trailing `Pool1d`, `Pool2d` or `Pool3d` (match object or `None`)\"\n",
    "    cls_name = l.__class__.__name__\n",
    "    return re.search(r'Pool[123]d$', cls_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "m = nn.Sequential(nn.AdaptiveAvgPool2d(5), nn.Linear(2,3), nn.Conv2d(2,3,1), nn.MaxPool3d(5))\n",
    "test_eq([bool(_is_pool_type(m_)) for m_ in m.children()], [True,False,False,True])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the fastai library cuts a pretrained model at the pooling layer. This function helps detect it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def has_pool_type(m):\n",
    "    \"Return `True` if `m` is a pooling layer or has one in its children\"\n",
    "    # depth-first search: `m` itself first, then every child subtree\n",
    "    if _is_pool_type(m): return True\n",
    "    return any(has_pool_type(child) for child in m.children())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = nn.Sequential(nn.AdaptiveAvgPool2d(5), nn.Linear(2,3), nn.Conv2d(2,3,1), nn.MaxPool3d(5))\n",
    "assert has_pool_type(m)\n",
    "test_eq([has_pool_type(m_) for m_ in m.children()], [True,False,False,True])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _get_first_layer(m):\n",
    "    \"Access first layer of a model\"\n",
    "    # the dotted name of the first parameter gives the attribute path to the module that owns it\n",
    "    path = next(m.named_parameters())[0].split('.')[:-1]\n",
    "    child,parent,name = m,None,None\n",
    "    for name in path:\n",
    "        parent,child = child,getattr(child,name)\n",
    "    return child,parent,name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _load_pretrained_weights(new_layer, previous_layer):\n",
    "    \"Load pretrained weights (and bias, when both layers have one) based on number of input channels\"\n",
    "    n_in = new_layer.in_channels\n",
    "    old_weight = previous_layer.weight.data\n",
    "    if n_in==1:\n",
    "        # we take the sum over the input-channel dim\n",
    "        new_layer.weight.data = old_weight.sum(dim=1, keepdim=True)\n",
    "    elif n_in==2:\n",
    "        # we take first 2 channels + 50%\n",
    "        new_layer.weight.data = old_weight[:,:2] * 1.5\n",
    "    else:\n",
    "        # keep 3 channels weights and set others to null\n",
    "        new_layer.weight.data[:,:3] = old_weight\n",
    "        new_layer.weight.data[:,3:].zero_()\n",
    "    # the bias is indexed by output channel only, so it transfers as-is; without this the\n",
    "    # new layer would keep its random bias init even though the weights are pretrained\n",
    "    old_bias,new_bias = getattr(previous_layer, 'bias', None),getattr(new_layer, 'bias', None)\n",
    "    if old_bias is not None and new_bias is not None: new_bias.data = old_bias.data.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _update_first_layer(model, n_in, pretrained):\n",
    "    \"Swap the first conv of `model` for one taking `n_in` channels (no-op when `n_in` is 3)\"\n",
    "    if n_in == 3: return\n",
    "    first_layer, parent, name = _get_first_layer(model)\n",
    "    assert isinstance(first_layer, nn.Conv2d), f'Change of input channels only supported with Conv2d, found {first_layer.__class__.__name__}'\n",
    "    assert getattr(first_layer, 'in_channels') == 3, f'Unexpected number of input channels, found {getattr(first_layer, \"in_channels\")} while expecting 3'\n",
    "    # the replacement copies every constructor setting except the number of input channels\n",
    "    copied = 'out_channels kernel_size stride padding dilation groups padding_mode'.split()\n",
    "    new_layer = nn.Conv2d(in_channels=n_in, bias=first_layer.bias is not None,\n",
    "                          **{attr:getattr(first_layer, attr) for attr in copied})\n",
    "    if pretrained: _load_pretrained_weights(new_layer, first_layer)\n",
    "    setattr(parent, name, new_layer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def cut_model(model, cut):\n",
    "    \"Cut an instantiated model\"\n",
    "    if isinstance(cut, int):\n",
    "        # integer: keep the children selected by the slice `[:cut]`\n",
    "        kept = list(model.children())[:cut]\n",
    "        return nn.Sequential(*kept)\n",
    "    # callable: whatever it returns for `model` is the result\n",
    "    if callable(cut): return cut(model)\n",
    "    raise NameError(\"cut must be either integer or a function\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def create_body(model, n_in=3, pretrained=True, cut=None):\n",
    "    \"Cut off the body of a typically pretrained `model` as determined by `cut`\"\n",
    "    _update_first_layer(model, n_in, pretrained)\n",
    "    if cut is None:\n",
    "        # default cut: index of the last top-level child that holds a pooling layer\n",
    "        children = list(model.children())\n",
    "        cut = next(i for i in reversed(range(len(children))) if has_pool_type(children[i]))\n",
    "    return cut_model(model, cut)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`cut` can either be an integer, in which case we cut the model at the corresponding layer, or a function, in which case, this function returns `cut(model)`. If `cut` is not passed, it defaults to the index of the last top-level child that contains some pooling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def tst(): return nn.Sequential(nn.Conv2d(3,5,3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3,4))\n",
    "# no `cut`: the model is cut at its pooling child (index 2), so 2 layers remain\n",
    "m = create_body(tst())\n",
    "test_eq(len(m), 2)\n",
    "\n",
    "# integer `cut`: the first 3 children are kept\n",
    "m = create_body(tst(), cut=3)\n",
    "test_eq(len(m), 3)\n",
    "\n",
    "# callable `cut`: whatever `cut(model)` returns is the body\n",
    "m = create_body(tst(), cut=noop)\n",
    "test_eq(len(m), 4)\n",
    "\n",
    "# the first layer follows the requested number of input channels\n",
    "for n in range(1,5):    \n",
    "    m = create_body(tst(), n_in=n)\n",
    "    test_eq(_get_first_layer(m)[0].in_channels, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Head and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def create_head(nf, n_out, lin_ftrs=None, ps=0.5, pool=True, concat_pool=True, first_bn=True, bn_final=False,\n",
    "                lin_first=False, y_range=None):\n",
    "    \"Model head that takes `nf` features, runs through `lin_ftrs`, and out `n_out` classes.\"\n",
    "    # concat pooling stacks two pooled outputs, so the first linear layer sees twice `nf`\n",
    "    if pool and concat_pool: nf *= 2\n",
    "    lin_ftrs = [nf, 512, n_out] if lin_ftrs is None else [nf] + lin_ftrs + [n_out]\n",
    "    bns = [first_bn] + [True]*len(lin_ftrs[1:])\n",
    "    # work on a private copy: `ps.pop(0)` below must never reach a list owned by the caller\n",
    "    ps = list(L(ps))\n",
    "    # a single probability is used for the last block; every earlier block gets half of it\n",
    "    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps\n",
    "    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]\n",
    "    layers = []\n",
    "    if pool:\n",
    "        pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)\n",
    "        layers += [pool, Flatten()]\n",
    "    # with `lin_first` the head opens with a dropout; `ps` is then one short, which makes\n",
    "    # the `zip` below stop before the last block, and the final `Linear` is appended by hand\n",
    "    if lin_first: layers.append(nn.Dropout(ps.pop(0)))\n",
    "    for ni,no,bn,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], bns, ps, actns):\n",
    "        layers += LinBnDrop(ni, no, bn=bn, p=p, act=actn, lin_first=lin_first)\n",
    "    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], n_out))\n",
    "    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))\n",
    "    if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The head begins with fastai's `AdaptiveConcatPool2d` if `concat_pool=True`; otherwise, it uses traditional average pooling. Then it uses a `Flatten` layer before going on blocks of `BatchNorm`, `Dropout` and `Linear` layers (if `lin_first=True`, those are `Linear`, `BatchNorm`, `Dropout`).\n",
    "\n",
    "Those blocks start at `nf`, then every element of `lin_ftrs` (defaults to `[512]`) and end at `n_out`. `ps` is a list of probabilities used for the dropouts (if you only pass 1, it will use half that value for every block except the last one, which uses the value itself).\n",
    "\n",
    "If `first_bn=True`, a `BatchNorm` is added just after the pooling operations. If `bn_final=True`, a final `BatchNorm` layer is added. If `y_range` is passed, the function adds a `SigmoidRange` to that range."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): AdaptiveConcatPool2d(\n",
       "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
       "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
       "  )\n",
       "  (1): fastai.layers.Flatten(full=False)\n",
       "  (2): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (3): Dropout(p=0.25, inplace=False)\n",
       "  (4): Linear(in_features=10, out_features=512, bias=False)\n",
       "  (5): ReLU(inplace=True)\n",
       "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (7): Dropout(p=0.5, inplace=False)\n",
       "  (8): Linear(in_features=512, out_features=10, bias=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tst = create_head(5, 10)\n",
    "tst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# default head: 9 modules, with a BatchNorm right after pool + flatten and a Linear at the end\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 9)\n",
    "assert isinstance(mods[2], nn.BatchNorm1d)\n",
    "assert isinstance(mods[-1], nn.Linear)\n",
    "\n",
    "# `lin_first=True`: a Dropout comes right after pool + flatten\n",
    "tst = create_head(5, 10, lin_first=True)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 8)\n",
    "assert isinstance(mods[2], nn.Dropout)\n",
    "\n",
    "# `first_bn=False`: the first BatchNorm is gone, so a Dropout sits at index 2\n",
    "tst = create_head(5, 10, first_bn=False)\n",
    "mods = list(tst.children())\n",
    "test_eq(len(mods), 8)\n",
    "assert isinstance(mods[2], nn.Dropout)\n",
    "\n",
    "# concat pooling doubles the input features of the first Linear (index 4)\n",
    "tst = create_head(5, 10, concat_pool=True)\n",
    "modes = list(tst.children())\n",
    "test_eq(modes[4].in_features, 10)\n",
    "\n",
    "# plain average pooling leaves the feature count unchanged\n",
    "tst = create_head(5, 10, concat_pool=False)\n",
    "modes = list(tst.children())\n",
    "test_eq(modes[4].in_features, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from fastai.callback.hook import num_features_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#TODO: refactor, i.e. something like this?\n",
    "# class ModelSplitter():\n",
    "#     def __init__(self, idx): self.idx = idx\n",
    "#     def split(self, m): return L(m[:self.idx], m[self.idx:]).map(params)\n",
    "#     def __call__(self,): return {'cut':self.idx, 'split':self.split}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def default_split(m):\n",
    "    \"Default split of a model between body and head\"\n",
    "    # two parameter groups: the first child, then everything after it\n",
    "    body,head = m[0],m[1:]\n",
    "    return L(body, head).map(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To do transfer learning, you need to pass a `splitter` to `Learner`. This should be a function taking the model and returning a collection of parameter groups, e.g. a list of lists of parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# Splitters: each one returns three parameter groups (through `params`): two slices of the\n",
    "# body `m[0]` (or of its first child `m[0][0]`) and everything after the body, `m[1:]`\n",
    "def _xresnet_split(m): return L(m[0][:3], m[0][3:], m[1:]).map(params)\n",
    "def  _resnet_split(m): return L(m[0][:6], m[0][6:], m[1:]).map(params)\n",
    "def _squeezenet_split(m:nn.Module): return L(m[0][0][:5], m[0][0][5:], m[1:]).map(params)\n",
    "def _densenet_split(m:nn.Module): return L(m[0][0][:7],m[0][0][7:], m[1:]).map(params)\n",
    "def _vgg_split(m:nn.Module): return L(m[0][0][:22], m[0][0][22:], m[1:]).map(params)\n",
    "def _alexnet_split(m:nn.Module): return L(m[0][0][:6], m[0][0][6:], m[1:]).map(params)\n",
    "\n",
    "# Per-family metadata: `cut` (where to cut the architecture), `split` (the splitter to use),\n",
    "# `stats` (normalization statistics) and, where present, `weights` (the value handed to the\n",
    "# architecture's `weights=` argument when none is passed explicitly)\n",
    "_default_meta    = {'cut':None, 'split':default_split}\n",
    "_xresnet_meta    = {'cut':-4, 'split':_xresnet_split, 'stats':imagenet_stats}\n",
    "_resnet_meta     = {'cut':-2, 'split':_resnet_split, 'stats':imagenet_stats, 'weights':'DEFAULT'}\n",
    "_squeezenet_meta = {'cut':-1, 'split': _squeezenet_split, 'stats':imagenet_stats, 'weights':'DEFAULT'}\n",
    "_densenet_meta   = {'cut':-1, 'split':_densenet_split, 'stats':imagenet_stats, 'weights':'DEFAULT'}\n",
    "_vgg_meta        = {'cut':-2, 'split':_vgg_split, 'stats':imagenet_stats, 'weights':'DEFAULT'}\n",
    "_alexnet_meta    = {'cut':-2, 'split':_alexnet_split, 'stats':imagenet_stats, 'weights':'DEFAULT'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "# Maps each supported architecture constructor to its metadata; `{**meta}` gives every\n",
    "# entry its own dict, so changing one entry does not affect the others\n",
    "model_meta = {\n",
    "    models.xresnet.xresnet18 :{**_xresnet_meta}, models.xresnet.xresnet34: {**_xresnet_meta},\n",
    "    models.xresnet.xresnet50 :{**_xresnet_meta}, models.xresnet.xresnet101:{**_xresnet_meta},\n",
    "    models.xresnet.xresnet152:{**_xresnet_meta},\n",
    "\n",
    "    models.resnet18 :{**_resnet_meta}, models.resnet34: {**_resnet_meta},\n",
    "    models.resnet50 :{**_resnet_meta}, models.resnet101:{**_resnet_meta},\n",
    "    models.resnet152:{**_resnet_meta},\n",
    "\n",
    "    models.squeezenet1_0:{**_squeezenet_meta},\n",
    "    models.squeezenet1_1:{**_squeezenet_meta},\n",
    "\n",
    "    models.densenet121:{**_densenet_meta}, models.densenet169:{**_densenet_meta},\n",
    "    models.densenet201:{**_densenet_meta}, models.densenet161:{**_densenet_meta},\n",
    "    models.vgg11_bn:{**_vgg_meta}, models.vgg13_bn:{**_vgg_meta}, models.vgg16_bn:{**_vgg_meta}, models.vgg19_bn:{**_vgg_meta},\n",
    "    models.alexnet:{**_alexnet_meta}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def add_head(body, nf, n_out, init=nn.init.kaiming_normal_, head=None, concat_pool=True, pool=True,\n",
    "                lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None):\n",
    "    \"Add a head to a vision body\"\n",
    "    if head is None:\n",
    "        # no custom head: build the default one from the remaining keyword arguments\n",
    "        head_kwargs = dict(concat_pool=concat_pool, pool=pool, lin_ftrs=lin_ftrs, ps=ps, first_bn=first_bn,\n",
    "                           bn_final=bn_final, lin_first=lin_first, y_range=y_range)\n",
    "        head = create_head(nf, n_out, **head_kwargs)\n",
    "    model = nn.Sequential(body, head)\n",
    "    # `init` is applied to the head (`model[1]`) only\n",
    "    if init is not None: apply_init(model[1], init)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def create_vision_model(arch, n_out, pretrained=True, weights=None, cut=None, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,\n",
    "                        concat_pool=True, pool=True, lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None):\n",
    "    \"Create custom vision architecture\"\n",
    "    meta = model_meta.get(arch, _default_meta)\n",
    "    # `weights=` is only handed to `arch` with torchvision>=0.13 and a meta that names default weights\n",
    "    multi_weight_api = parse(torchvision.__version__) >= parse('0.13') and 'weights' in meta\n",
    "    if not multi_weight_api: model = arch(pretrained=pretrained)\n",
    "    else:\n",
    "        if weights is not None and not pretrained:\n",
    "            warn(f'{pretrained=} but `weights` are set {weights=}. To randomly initialize set `pretrained=False` & `weights=None`')\n",
    "        use_meta_default = weights is None and pretrained\n",
    "        model = arch(weights=meta['weights'] if use_meta_default else weights)\n",
    "    body = create_body(model, n_in, pretrained, ifnone(cut, meta['cut']))\n",
    "    # the number of features is only computed when the head is built here\n",
    "    nf = None if custom_head is not None else num_features_model(nn.Sequential(*body.children()))\n",
    "    head_kwargs = dict(concat_pool=concat_pool, pool=pool, lin_ftrs=lin_ftrs, ps=ps, first_bn=first_bn,\n",
    "                       bn_final=bn_final, lin_first=lin_first, y_range=y_range)\n",
    "    return add_head(body, nf, n_out, init=init, head=custom_head, **head_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py#L163){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### create_vision_model\n",
       "\n",
       ">      create_vision_model (arch, n_out, pretrained=True, weights=None,\n",
       ">                           cut=None, n_in=3, init=<function kaiming_normal_>,\n",
       ">                           custom_head=None, concat_pool=True, pool=True,\n",
       ">                           lin_ftrs=None, ps=0.5, first_bn=True,\n",
       ">                           bn_final=False, lin_first=False, y_range=None)\n",
       "\n",
       "Create custom vision architecture"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py#L163){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### create_vision_model\n",
       "\n",
       ">      create_vision_model (arch, n_out, pretrained=True, weights=None,\n",
       ">                           cut=None, n_in=3, init=<function kaiming_normal_>,\n",
       ">                           custom_head=None, concat_pool=True, pool=True,\n",
       ">                           lin_ftrs=None, ps=0.5, first_bn=True,\n",
       ">                           bn_final=False, lin_first=False, y_range=None)\n",
       "\n",
       "Create custom vision architecture"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(create_vision_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is cut according to `cut` and it may be `pretrained`, in which case, the proper set of weights is downloaded then loaded. `init` is applied to the head of the model, which is either created by `create_head` (with `lin_ftrs`, `ps`, `concat_pool`, `bn_final`, `lin_first` and `y_range`) or is `custom_head`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = create_vision_model(models.resnet18, 10, True)\n",
    "tst = create_vision_model(models.resnet18, 10, True, n_in=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TimmBody(nn.Module):\n",
    "    \"Wrap a `timm` model so it can be used as the body of a model with a separate head\"\n",
    "    def __init__(self, model, pretrained:bool=True, cut=None, n_in:int=3):\n",
    "        # `pretrained` and `n_in` are accepted but not used in this class\n",
    "        super().__init__()\n",
    "        # `needs_pool` is True when the model's `default_cfg` has a non-None `pool_size` entry\n",
    "        self.needs_pool = model.default_cfg.get('pool_size', None) is not None\n",
    "        self.model = model if cut is None else cut_model(model, cut)\n",
    "    \n",
    "    # with `needs_pool` the output of `forward_features` is returned, otherwise the model's regular output\n",
    "    def forward(self,x): return self.model.forward_features(x) if self.needs_pool else self.model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def create_timm_model(arch, n_out, cut=None, pretrained=True, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,\n",
    "                     concat_pool=True, pool=True, lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None, **kwargs):\n",
    "    \"Create custom architecture using `arch`, `n_in` and `n_out` from the `timm` library\"\n",
    "    model = timm.create_model(arch, pretrained=pretrained, num_classes=0, in_chans=n_in, **kwargs)\n",
    "    # the body is always built uncut: the `cut` argument is not used here\n",
    "    body = TimmBody(model, pretrained, None, n_in)\n",
    "    head_kwargs = dict(concat_pool=concat_pool, lin_ftrs=lin_ftrs, ps=ps, first_bn=first_bn,\n",
    "                       bn_final=bn_final, lin_first=lin_first, y_range=y_range)\n",
    "    # pooling in the head follows `body.needs_pool`, not the `pool` argument\n",
    "    full_model = add_head(body, body.model.num_features, n_out, init=init, head=custom_head,\n",
    "                          pool=body.needs_pool, **head_kwargs)\n",
    "    return full_model,model.default_cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make sure that timm models can be scripted:\n",
    "tst, _ = create_timm_model('resnet34', 1)\n",
    "scripted = torch.jit.script(tst)\n",
    "assert scripted, \"model could not be converted to TorchScript\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `Learner` convenience functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _add_norm(dls, meta, pretrained, n_in=3):\n",
    "    \"Add a `Normalize` built from `meta['stats']` to `dls` unless one is already there\"\n",
    "    if not pretrained: return\n",
    "    stats = meta.get('stats')\n",
    "    # no stats, or stats that don't match the number of input channels: nothing to add\n",
    "    if stats is None or n_in != len(stats[0]): return\n",
    "    has_norm = dls.after_batch.fs.filter(risinstance(Normalize))\n",
    "    if has_norm: return\n",
    "    dls.add_tfms([Normalize.from_stats(*stats)],'after_batch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "path = untar_data(URLs.PETS)\n",
    "dls = ImageDataLoaders.from_name_re(path, get_image_files(path/\"images\"), r'^(.*)_\\d+.jpg$', item_tfms=Resize(224))\n",
    "for _ in range(5): _add_norm(dls, model_meta[models.resnet34], True)\n",
    "test_eq(len(dls.after_batch.fs), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def _timm_norm(dls, cfg, pretrained, n_in=3):\n",
    "    \"Add a `Normalize` built from `cfg['mean']` and `cfg['std']` to `dls` unless one is already there\"\n",
    "    # not pretrained, or stats that don't match the number of input channels: nothing to add\n",
    "    if not pretrained or n_in != len(cfg['mean']): return\n",
    "    has_norm = dls.after_batch.fs.filter(risinstance(Normalize))\n",
    "    if has_norm: return\n",
    "    dls.add_tfms([Normalize.from_stats(cfg['mean'],cfg['std'])],'after_batch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates(create_vision_model)\n",
    "def vision_learner(dls, arch, normalize=True, n_out=None, pretrained=True, weights=None,\n",
    "        # learner args\n",
    "        loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,\n",
    "        model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),\n",
    "        # model & head args\n",
    "        cut=None, init=nn.init.kaiming_normal_, custom_head=None, concat_pool=True, pool=True,\n",
    "        lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None, **kwargs):\n",
    "    \"Build a vision learner from `dls` and `arch`\"\n",
    "    if n_out is None: n_out = get_c(dls)\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    meta = model_meta.get(arch, _default_meta)\n",
    "    model_args = dict(init=init, custom_head=custom_head, concat_pool=concat_pool, pool=pool, lin_ftrs=lin_ftrs, ps=ps,\n",
    "                      first_bn=first_bn, bn_final=bn_final, lin_first=lin_first, y_range=y_range, **kwargs)\n",
    "    n_in = kwargs.get('n_in', 3)\n",
    "    if isinstance(arch, str):\n",
    "        # a string selects a timm architecture; the third positional argument of\n",
    "        # `create_timm_model` is `cut`, so pass `cut` there (not a splitter)\n",
    "        model,cfg = create_timm_model(arch, n_out, cut, pretrained, **model_args)\n",
    "        if normalize: _timm_norm(dls, cfg, pretrained, n_in)\n",
    "    else:\n",
    "        if normalize: _add_norm(dls, meta, pretrained, n_in)\n",
    "        # `cut` is a named parameter here, so it never reaches `kwargs`: forward it explicitly,\n",
    "        # otherwise a user-supplied `cut` would be silently replaced by the meta default\n",
    "        model = create_vision_model(arch, n_out, pretrained=pretrained, weights=weights, cut=cut, **model_args)\n",
    "\n",
    "    splitter = ifnone(splitter, meta['split'])\n",
    "    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,\n",
    "                   metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn, moms=moms)\n",
    "    # with pretrained weights, training starts with the learner frozen\n",
    "    if pretrained: learn.freeze()\n",
    "    # keep track of args for loggers\n",
    "    store_attr('arch,normalize,n_out,pretrained', self=learn, **kwargs)\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is built from `arch` using the number of final activations inferred from `dls` if possible (otherwise pass a value to `n_out`). It might be `pretrained` and the architecture is cut and split using the default metadata of the model architecture (this can be customized by passing a `cut` or a `splitter`).\n",
    "\n",
    "If `normalize` and `pretrained` are `True`, this function adds a `Normalize` transform to the `dls` (if there is not already one) using the statistics of the pretrained model. That way, you won't ever forget to normalize your data in transfer learning.\n",
    "\n",
    "All other arguments are passed to `Learner`.\n",
    "\n",
    "Starting with version 0.13, TorchVision supports [multiple pretrained weights](https://pytorch.org/vision/stable/models.html#initializing-pre-trained-models) for the same model architecture. The <code>vision_learner</code> default of `pretrained=True, weights=None` will use the architecture's default weights, which are currently IMAGENET1K_V2. If you are using an older version of TorchVision or creating a [timm](https://huggingface.co/docs/timm/index) model, setting `weights` will have no effect.\n",
    "\n",
    "```python\n",
    "from torchvision.models import ResNet50_Weights\n",
    "\n",
    "# Legacy weights with accuracy 76.130%\n",
    "vision_learner(dls, models.resnet50, pretrained=True, weights=ResNet50_Weights.IMAGENET1K_V1, ...)\n",
    "\n",
    "# New weights with accuracy 80.858%. Strings are also supported.\n",
    "vision_learner(dls, models.resnet50, pretrained=True, weights='IMAGENET1K_V2', ...)\n",
    "\n",
    "# Best available weights (currently an alias for IMAGENET1K_V2).\n",
    "# Default weights if vision_learner weights isn't set.\n",
    "vision_learner(dls, models.resnet50, pretrained=True, weights=ResNet50_Weights.DEFAULT, ...)\n",
    "\n",
    "# No weights - random initialization\n",
    "vision_learner(dls, models.resnet50, pretrained=False, weights=None, ...)\n",
    "```\n",
    "\n",
    "The example above shows how to use the new TorchVision 0.13 multi-weight api with <code>vision_learner</code>."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = untar_data(URLs.PETS)\n",
    "fnames = get_image_files(path/\"images\")\n",
    "pat = r'^(.*)_\\d+.jpg$'\n",
    "dls = ImageDataLoaders.from_name_re(path, fnames, pat, item_tfms=Resize(224))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = vision_learner(dls, models.resnet18, loss_func=CrossEntropyLossFlat(), ps=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "if parse(torchvision.__version__) >= parse('0.13'):\n",
    "    from torchvision.models import ResNet34_Weights\n",
    "    weights = ResNet34_Weights.IMAGENET1K_V1\n",
    "else:\n",
    "    weights = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "learn = vision_learner(dls, models.resnet34, weights=weights, loss_func=CrossEntropyLossFlat(), ps=0.25, concat_pool=False)\n",
    "test_ne(learn.cbs, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "test_eq(to_cpu(dls.after_batch[1].mean[0].squeeze()), tensor(imagenet_stats[0]))\n",
    "test_eq(to_cpu(dls.valid.after_batch[1].mean[0].squeeze()), tensor(imagenet_stats[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you pass a `str` to `arch`, then a [timm](https://huggingface.co/docs/timm/index) model will be created:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls = ImageDataLoaders.from_name_re(path, fnames, pat, item_tfms=Resize(224))\n",
    "learn = vision_learner(dls, 'convnext_tiny', loss_func=CrossEntropyLossFlat(), ps=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates(models.unet.DynamicUnet.__init__)\n",
    "def create_unet_model(arch, n_out, img_size, pretrained=True, weights=None, cut=None, n_in=3, **kwargs):\n",
    "    \"Create custom unet architecture\"\n",
    "    meta = model_meta.get(arch, _default_meta)\n",
    "    # `weights=` is only handed to `arch` with torchvision>=0.13 and a meta that names default weights\n",
    "    multi_weight_api = parse(torchvision.__version__) >= parse('0.13') and 'weights' in meta\n",
    "    if not multi_weight_api: encoder = arch(pretrained=pretrained)\n",
    "    else:\n",
    "        if weights is not None and not pretrained:\n",
    "            warn(f'{pretrained=} but `weights` are set {weights=}. To randomly initialize set `pretrained=False` & `weights=None`')\n",
    "        use_meta_default = weights is None and pretrained\n",
    "        encoder = arch(weights=meta['weights'] if use_meta_default else weights)\n",
    "    body = create_body(encoder, n_in, pretrained, ifnone(cut, meta['cut']))\n",
    "    # the remaining keyword arguments go to `DynamicUnet` unchanged\n",
    "    return models.unet.DynamicUnet(body, n_out, img_size, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### create_unet_model\n",
       "\n",
       ">      create_unet_model (arch, n_out, img_size, pretrained=True, weights=None,\n",
       ">                         cut=None, n_in=3, blur=False, blur_final=True,\n",
       ">                         self_attention=False, y_range=None, last_cross=True,\n",
       ">                         bottle=False, act_cls=<class\n",
       ">                         'torch.nn.modules.activation.ReLU'>, init=<function\n",
       ">                         kaiming_normal_>, norm_type=None)\n",
       "\n",
       "Create custom unet architecture"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/master/fastai/vision/learner.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### create_unet_model\n",
       "\n",
       ">      create_unet_model (arch, n_out, img_size, pretrained=True, weights=None,\n",
       ">                         cut=None, n_in=3, blur=False, blur_final=True,\n",
       ">                         self_attention=False, y_range=None, last_cross=True,\n",
       ">                         bottle=False, act_cls=<class\n",
       ">                         'torch.nn.modules.activation.ReLU'>, init=<function\n",
       ">                         kaiming_normal_>, norm_type=None)\n",
       "\n",
       "Create custom unet architecture"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(create_unet_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: resnet18 backbone, n_out=10, img_size=(24,24), pretrained=True, single input channel (n_in=1)\n",
    "tst = create_unet_model(models.resnet18, 10, (24,24), True, n_in=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates(create_unet_model)\n",
    "def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, weights=None, config=None,\n",
    "                 # learner args\n",
    "                 loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,\n",
    "                 model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95), **kwargs):\n",
    "    \"Build a unet learner from `dls` and `arch`\"\n",
    "    # Deprecated `config` dict: explicitly passed keyword arguments win over its entries\n",
    "    if config:\n",
    "        warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')\n",
    "        kwargs = {**config, **kwargs}\n",
    "\n",
    "    meta = model_meta.get(arch, _default_meta)\n",
    "    if normalize:\n",
    "        # The input channel count defaults to 3 unless `n_in` was supplied through `kwargs`\n",
    "        _add_norm(dls, meta, pretrained, kwargs.get('n_in', 3))\n",
    "\n",
    "    # `n_out` is rebound under its own name (presumably so `store_attr` below records the resolved value)\n",
    "    n_out = ifnone(n_out, get_c(dls))\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    # Image size = last two dimensions of the first element of one batch\n",
    "    img_size = dls.one_batch()[0].shape[-2:]\n",
    "    assert img_size, \"image size could not be inferred from data\"\n",
    "    unet = create_unet_model(arch, n_out, img_size, pretrained=pretrained, weights=weights, **kwargs)\n",
    "\n",
    "    learn = Learner(\n",
    "        dls=dls, model=unet, loss_func=loss_func, opt_func=opt_func, lr=lr,\n",
    "        # without an explicit `splitter`, use the one stored under 'split' in `meta`\n",
    "        splitter=ifnone(splitter, meta['split']),\n",
    "        cbs=cbs, metrics=metrics, path=path, model_dir=model_dir, wd=wd,\n",
    "        wd_bn_bias=wd_bn_bias, train_bn=train_bn, moms=moms)\n",
    "    # A pretrained model starts out frozen\n",
    "    if pretrained: learn.freeze()\n",
    "    # keep track of args for loggers\n",
    "    store_attr('arch,normalize,n_out,pretrained', self=learn, **kwargs)\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is built from `arch` using the number of final filters inferred from `dls` if possible (otherwise pass a value to `n_out`). It might be `pretrained` and the architecture is cut and split using the default metadata of the model architecture (this can be customized by passing a `cut` or a `splitter`).\n",
    "\n",
    "If `normalize` and `pretrained` are `True`, this function adds a `Normalize` transform to the `dls` (if there is not already one) using the statistics of the pretrained model. That way, you won't ever forget to normalize your data in transfer learning.\n",
    "\n",
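    "The learner arguments (`loss_func`, `opt_func`, `lr`, `splitter`, `cbs`, `metrics`, `path`, `model_dir`, `wd`, `wd_bn_bias`, `train_bn` and `moms`) are passed to `Learner`; any remaining keyword arguments are forwarded to `create_unet_model`.\n",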
    "\n",
    "<code>unet_learner</code> also supports TorchVision's new multi-weight API via `weights`. See `vision_learner` for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the `CAMVID_TINY` data, the files under `images`, and the class codes from `codes.txt`\n",
    "path = untar_data(URLs.CAMVID_TINY)\n",
    "fnames = get_image_files(path/'images')\n",
    "codes = np.loadtxt(path/'codes.txt', dtype=str)\n",
    "\n",
    "def label_func(x):\n",
    "    # Label file for image `x`: same stem with a `_P` suffix and same extension, under `labels`\n",
    "    return path/'labels'/f'{x.stem}_P{x.suffix}'\n",
    "\n",
    "dls = SegmentationDataLoaders.from_label_func(path, fnames, label_func, codes=codes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `y_range` is not a named `unet_learner` parameter: it travels through `**kwargs` to `create_unet_model`\n",
    "learn = unet_learner(dls, models.resnet34, loss_func=CrossEntropyLossFlat(axis=1), y_range=(0,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "# Check that the learner built above has its `cbs` attribute set (i.e. it is not `None`)\n",
    "test_ne(learn.cbs, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def create_cnn_model(*args, **kwargs):\n",
    "    \"Deprecated name for `create_vision_model` -- do not use\"\n",
    "    # Tell the caller about the rename, then forward every argument unchanged\n",
    "    rename_notice = \"`create_cnn_model` has been renamed to `create_vision_model` -- please update your code\"\n",
    "    warn(rename_notice)\n",
    "    model = create_vision_model(*args, **kwargs)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def cnn_learner(*args, **kwargs):\n",
    "    \"Deprecated name for `vision_learner` -- do not use\"\n",
    "    # Tell the caller about the rename, then forward every argument unchanged\n",
    "    rename_notice = \"`cnn_learner` has been renamed to `vision_learner` -- please update your code\"\n",
    "    warn(rename_notice)\n",
    "    learn = vision_learner(*args, **kwargs)\n",
    "    return learn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show functions -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def show_results(x:TensorImage, y, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    # Without caller-provided contexts, make a grid with one cell per shown sample (at most `max_n`)\n",
    "    if ctxs is None:\n",
    "        n_shown = min(len(samples), max_n)\n",
    "        ctxs = get_grid(n_shown, nrows=nrows, ncols=ncols, figsize=figsize)\n",
    "    # The drawing itself is done by `show_results[object]`, which receives the grid\n",
    "    return show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def show_results(x:TensorImage, y:TensorCategory, samples, outs, ctxs=None, max_n=10, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    if ctxs is None:\n",
    "        ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)\n",
    "    # Two passes over the samples: element 0 of each sample, then element 1, drawn on the same context\n",
    "    for idx in range(2):\n",
    "        drawn = []\n",
    "        for item,ctx,_ in zip(samples.itemgot(idx), ctxs, range(max_n)):\n",
    "            drawn.append(item.show(ctx=ctx, **kwargs))\n",
    "        ctxs = drawn\n",
    "    # Final pass: show the first output, green when it equals element 1 of the sample and red otherwise\n",
    "    coloured = []\n",
    "    for target,pred,ctx,_ in zip(samples.itemgot(1), outs.itemgot(0), ctxs, range(max_n)):\n",
    "        shade = 'green' if target==pred else 'red'\n",
    "        coloured.append(pred.show(ctx=ctx, color=shade, **kwargs))\n",
    "    return coloured"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def show_results(x:TensorImage, y:TensorMask|TensorPoint|TensorBBox, samples, outs, ctxs=None, max_n=6,\n",
    "                 nrows=None, ncols=1, figsize=None, **kwargs):\n",
    "    if ctxs is None:\n",
    "        ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize, double=True,\n",
    "                        title='Target/Prediction')\n",
    "    # Even-indexed contexts: element 0 of each sample, then element 1 on top of it\n",
    "    for idx in range(2):\n",
    "        shown = []\n",
    "        for item,ctx,_ in zip(samples.itemgot(idx), ctxs[::2], range(2*max_n)):\n",
    "            shown.append(item.show(ctx=ctx, **kwargs))\n",
    "        ctxs[::2] = shown\n",
    "    # Odd-indexed contexts: element 0 of each sample, then element 0 of each output on top of it\n",
    "    for source in (samples, outs):\n",
    "        shown = []\n",
    "        for item,ctx,_ in zip(source.itemgot(0), ctxs[1::2], range(2*max_n)):\n",
    "            shown.append(item.show(ctx=ctx, **kwargs))\n",
    "        ctxs[1::2] = shown\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def show_results(x:TensorImage, y:TensorImage, samples, outs, ctxs=None, max_n=10, figsize=None, **kwargs):\n",
    "    # Three contexts per shown sample, laid out in three columns\n",
    "    if ctxs is None:\n",
    "        n_cells = 3*min(len(samples), max_n)\n",
    "        ctxs = get_grid(n_cells, ncols=3, figsize=figsize, title='Input/Target/Prediction')\n",
    "    # Columns 0 and 1 take elements 0 and 1 of each sample\n",
    "    for col in range(2):\n",
    "        shown = []\n",
    "        for item,ctx,_ in zip(samples.itemgot(col), ctxs[col::3], range(max_n)):\n",
    "            shown.append(item.show(ctx=ctx, **kwargs))\n",
    "        ctxs[col::3] = shown\n",
    "    # Column 2 takes element 0 of each output\n",
    "    shown = []\n",
    "    for item,ctx,_ in zip(outs.itemgot(0), ctxs[2::3], range(max_n)):\n",
    "        shown.append(item.show(ctx=ctx, **kwargs))\n",
    "    ctxs[2::3] = shown\n",
    "    return ctxs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def plot_top_losses(x: TensorImage, y:TensorCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    # One grid cell per sample; the figure title names the four fields of each per-image caption\n",
    "    grid = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize, title='Prediction/Actual/Loss/Probability')\n",
    "    for ax,s,o,r,l in zip(grid, samples, outs, raws, losses):\n",
    "        s[0].show(ctx=ax, **kwargs)\n",
    "        # caption: first output / element 1 of the sample / loss value / largest value in the raw output\n",
    "        caption = f'{o[0]}/{s[1]} / {l.item():.2f} / {r.max().item():.2f}'\n",
    "        ax.set_title(caption)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def plot_top_losses(x: TensorImage, y:TensorMultiCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    n_samples = len(samples)\n",
    "    axs = get_grid(n_samples, nrows=nrows, ncols=ncols, figsize=figsize)\n",
    "    # Each image is titled with its index so it can be matched against the table displayed below\n",
    "    for i,(ax,s) in enumerate(zip(axs, samples)):\n",
    "        s[0].show(ctx=ax, title=f'Image {i}', **kwargs)\n",
    "    rows = get_empty_df(n_samples)\n",
    "    # Per sample: everything after element 0, then the outputs, the raw output as a string, the loss as a float\n",
    "    records = L(s[1:] + o + (TitledStr(r), TitledFloat(l.item())) for s,o,r,l in zip(samples, outs, raws, losses))\n",
    "    # Fill the table one labelled column at a time\n",
    "    for col,label in enumerate((\"target\", \"predicted\", \"probabilities\", \"loss\")):\n",
    "        filled = []\n",
    "        for cell,row in zip(records.itemgot(col), rows):\n",
    "            filled.append(cell.show(ctx=row, label=label, **kwargs))\n",
    "        rows = filled\n",
    "    display_df(pd.DataFrame(rows))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def plot_top_losses(x:TensorImage, y:TensorMask, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, **kwargs):\n",
    "    n_rows = len(samples)\n",
    "    # One row of three axes per sample, kept unflattened so rows can be iterated\n",
    "    axes = get_grid(n_rows*3, nrows=n_rows, ncols=3, figsize=figsize, flatten=False, title=\"Input | Target | Prediction\")\n",
    "    # A 1-d result (presumably the single-row case) is wrapped so the loop below still sees one row\n",
    "    if axes.ndim == 1: axes = (axes,)\n",
    "    for row_axs,s,o,l in zip(axes, samples, outs, losses):\n",
    "        panels = ((\"input\", s[0]), (\"target\", s[1]), (\"pred\", o[0]))\n",
    "        for ax,(name,im) in zip(row_axs, panels):\n",
    "            # only the prediction panel carries the loss in its title\n",
    "            caption = (name + f\"; loss = {l.item():.4f}\") if name==\"pred\" else name\n",
    "            im.show(ctx=ax, **kwargs)\n",
    "            ax.set_title(caption)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
